{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "%pylab inline\n",
    "import os, sys\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.linear_model import Ridge\n",
    "\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "tqdm.pandas()\n",
    "sys.path.append('..')\n",
    "import assign_loop_type\n",
    "from assign_loop_type import write_loop_assignments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Code to reproduce DegScore from Kaggle datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_input(df, window_size=1, pad=10, seq=True, struct=True, ensemble_size=0):\n",
    "    '''Creat input/output for regression model for predicting structure probing data.\n",
    "    Inputs:\n",
    "    \n",
    "    dataframe (in EternaBench RDAT format)\n",
    "    window_size: size of window (in one direction). so window_size=1 is a total window size of 3\n",
    "    pad: number of nucleotides at start to not include\n",
    "    seq (bool): include sequence encoding\n",
    "    struct (bool): include bpRNA structure encoding\n",
    "    \n",
    "    Outputs:\n",
    "    Input array (n_samples x n_features): array of windowed input features\n",
    "    feature_names (list): feature names\n",
    "    \n",
    "    '''\n",
    "    #MAX_LEN = 68\n",
    "    BASES = ['A','U','G','C']\n",
    "    STRUCTS = ['H','E','I','M','B','S']\n",
    "    \n",
    "    inpts = []\n",
    "    labels = []\n",
    "\n",
    "    feature_kernel=[]\n",
    "    if seq:\n",
    "        feature_kernel.extend(BASES)\n",
    "    if struct:\n",
    "        feature_kernel.extend(STRUCTS)\n",
    "\n",
    "    feature_names = ['%s_%d' % (k, val) for val in range(-1*window_size, window_size+1) for k in feature_kernel]\n",
    "    \n",
    "    for i, row in tqdm(df.iterrows(), desc='Encoding inputs', total=len(df)):\n",
    "        MAX_LEN = len(row['sequence'])-39 #68 for RYOS-I\n",
    "        \n",
    "        arr = np.zeros([MAX_LEN,len(feature_kernel)])\n",
    "        \n",
    "        if ensemble_size > 0: # stochastically sample ensemble\n",
    "            ensemble = get_ensemble(row['sequence'], n=ensemble_size)\n",
    "        else: # use MEA structure\n",
    "            ensemble = np.array([list(row['predicted_loop_type'])])\n",
    "\n",
    "        for index in range(pad,MAX_LEN):\n",
    "            ctr=0\n",
    "\n",
    "            #encode sequence\n",
    "            if seq:\n",
    "                for char in BASES:\n",
    "                    if row['sequence'][index]==char:\n",
    "                        arr[index,ctr]+=1\n",
    "                    ctr+=1\n",
    "\n",
    "            if struct:\n",
    "                loop_assignments = ''.join(ensemble[:,index])\n",
    "                for char in STRUCTS:\n",
    "                    prob = loop_assignments.count(char) / len(loop_assignments)\n",
    "                    arr[index,ctr]+=prob\n",
    "                    ctr+=1\n",
    "                    \n",
    "        # add zero padding to the side\n",
    "        padded_arr = np.vstack([np.zeros([window_size,len(feature_kernel)]),arr[pad:], np.zeros([window_size,len(feature_kernel)])])\n",
    "\n",
    "        for index in range(pad,MAX_LEN):\n",
    "            new_index = index+window_size-pad\n",
    "            tmp = padded_arr[new_index-window_size:new_index+window_size+1]\n",
    "            inpts.append(tmp.flatten())\n",
    "            labels.append('%s_%d' % (row['id'], index))\n",
    "            \n",
    "    return np.array(inpts), feature_names, labels\n",
    "\n",
    "def encode_output(df, data_type='reactivity', pad=10):\n",
    "    '''Creat input/output for regression model for predicting structure probing data.\n",
    "    Inputs:\n",
    "    \n",
    "    dataframe (in EternaBench RDAT format)\n",
    "    data_type: column name for degradation\n",
    "    window_size: size of window (in one direction). so window_size=1 is a total window size of 3\n",
    "    pad: number of nucleotides at start to not include\n",
    "    \n",
    "    Outputs:\n",
    "    output array (n_samples): array of reactivity values\n",
    "    \n",
    "    '''\n",
    "    #MAX_LEN = 68\n",
    "    \n",
    "    outpts = []\n",
    "    labels = []\n",
    "    # output identity should be in form id_00073f8be_0\n",
    "\n",
    "    for i, row in df.iterrows():\n",
    "        MAX_LEN = len(row['sequence'])-39\n",
    "        \n",
    "        for index in range(pad,MAX_LEN):\n",
    "            outpts.append(row[data_type][index])\n",
    "            labels.append('%s_%d' % (row['id'], index))\n",
    "            \n",
    "            \n",
    "    return outpts, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "kaggle_train = pd.read_json('train.json',lines=True)\n",
    "kaggle_train = kaggle_train.loc[kaggle_train['SN_filter']==1]\n",
    "\n",
    "kaggle_test = pd.read_json('test.json',lines=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Encode data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###### Max. expected accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d6ee9daf5dfe40b189ffe14079cec14c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Encoding inputs', max=1589, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b583d18613ff4c25898ea1feec9e2eef",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Encoding inputs', max=3634, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "inputs_train, feature_names, _ = encode_input(kaggle_train, window_size=12)\n",
    "inputs_test, _, test_labels = encode_input(kaggle_test, window_size=12)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Visualize encoding for an example nucleotide"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5, 0, 'window position')"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARsAAACiCAYAAABxn2koAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAQdklEQVR4nO3de7RcZXnH8e/PpJhA5BISaYCEIA13KeUctALK1RYpLKC1QmDJpS1UxVJbU0vRIixkqV2IWg2yImLCHYpiKaSsUDUEbbicA4EAAQpCuCSFJFwj9/D0j/0OnTWZk+w5Z/aemZ3fZ61ZZ99m72efOedZ7/vu931HEYGZWdHe0+kAzGzD4GRjZqVwsjGzUjjZmFkpnGzMrBRONmZWCicb6wmSzpZ0eVqeImm1pFGdjsvyc7LpQZKekPSmpAkN2xdJCklT0/rsdNzqutcxded4rWHf98u/m9ZFxJMRMS4i1nQ6FsvPyaZ3PQ5Mr61I+iAwtslx/5L+MWuva+r2HdGw7/NFB20bLieb3nUZcELd+onApUVcSNJ7JJ0h6TFJqyRdK2l82jc1laZOlPSkpJWSvlz33lGSzkzvfUXSoKTJad8+ku6S9FL6uU/d+7aXdGt6zy3AhLp9tWuOTuvzJZ0r6dfp+Hn1pT5JJ0hammL/51SqO6SI35UNzcmmd90ObCppl9R2cQxweUHXOh04Ctgf2Bp4AZjZcMx+wE7AwcBZknZJ2/+erAR2GLAp8BfAqylZ3QT8K7AlcAFwk6Qt0/uuBAbJksy5ZMl0XY4DTgbeD2wEzACQtCtwIXA8MAnYDNimpbu39ogIv3rsBTwBHAJ8Bfg6cChwCzAaCGBqOm428DrwYnqtbDjH6rp9LwKnDHG9JcDBdeuTgLfS9aama25bt/9O4Ni0/DBwZJNzfhq4s2HbQuAkYArwNrBJ3b4rgcvTcu2ao9P6fOArdcd+Drg5LZ8FXFW3b2PgTeCQTn+OG9pr9DBzlHWHy4AFwPYMXYU6PyK+MsS+oyLiv3JcZzvgeknv1G1bA2xVt/6/dcuvAuPS8mTgsSbn3BpY2rBtKVmpY2vghYj4bcO+yeuIcajrbw08VdsREa9KWrWO81hBXI3qYRGxlKyh+DDgpwVe6ingExGxed1rTEQ8k/O9OzTZvowsidWbAjwDLAe2kLRJw77hWA5sW1uRNJas2mYlc7LpfX8JHNRQCmi3i4DzJG0HIGmipCNzvvdi4FxJ05TZI7XLzAV2lHScpNHpkfyuwI0piQ4A50jaSNJ+wBHDjP064IjUGL0RcA6gYZ7LRsDJpsdFxGMRMTDMt/9HQz+b64c47rvADcA8Sa+QNU5/OOc1LgCuBeYBLwM/AsZGxCrgcOCLwCrgS8DhEbEyve+4dI3nga8yzCdtEfEA8DfA1WSlnFeA54A3hnM+Gz6lRjOzDYKkcWSN4dMi4vFOx7MhccnGKk/SEZI2Tm1A5wOLyZ7GWYmcbGxDcCRZg/QyYBrZY3kX6UvmapSZlcIlGzMrRaGd+iTlKjb19fUVGcY6DQ4O5jqukzG2W7ffc974WtHtn1+777mV+233tSOiadeCQqtReZNNJ6tyUr4uF1Wqbnb7PeeNrxXd/vm1+55bud8Crt30hC1VoyQdnUbb7tyesMxsQ9Fqm8104FfAsQXEYmYVljvZpM5Q+5J1j3eyMbOWtFKyOYps2P4jwPOS9mp2kKRTJQ1IGm4XejOroNwNxJJuAr4TEbdIOh2YHBH/sJ73uIG4C3X7PbuBeOS6sYE4V7JJo3SfJhvAFsCo9HO7dfXEdLLpTt1+z042I9eNySZvNeqTwKURsV1ETI2IyWTzqOzXrgDNrNryJpvpQOP0Az8hmwbAzGy9Cu3U19/fHwMD5bcTt1IsbKHNarjhjEgR8XV79cjV6pHr1N9Df38/AwMDI+/UZ2Y2XLmSTfqenvsbtp0taUYxYZlZ1bhkY2alcLIxs1I42ZhZKfImm6Gaq9faXj9cYcWKFcOPzMwqJW+yWQVs0bBtPLCy8cCImBUR/RHRP3HixJHGZ2YVkSvZRMRqYLmkgwHSl8IfSjbdhJnZerUyLegJwExJ30rr50REs+9wNjNbS+5kExEPAgcWGMt6dbJ3Z7f3GC0ivk4ODuyUTvUob/fvpht/134aZWalcLIxs1K0Mi3oVpKulPQbSYOSFko6usjgzKw68o6NEvAzYEFEfCAi+sjmId62yODMrDryNhAfBLwZERfVNkTEUuB7hURlZpWTtxq1G3B3ngPdg9jMmhlWA7GkmZLulXRX4z73IDazZvImmweAd7+6JSJOAw4GnE3MLJe8yeYXwBhJn63btnEB8ZhZReUdGxVkX1K3v6THJd0JzAH+scjgzKw6WhmusJyCvna3KpNMF8FDNNqj3b/HKv1uyuIexGZWilZ6EP+upKslPSbpQUlzJe1YZHBmVh2t9CC+HpgfETtExK7AmcBWRQZnZtWRt83mQOCthh7Ei4oJycyqKG81andgMM+B7kFsZs20vYHYPYjNrJlWehD3FRmImVVbKz2I3yvplNoGSXtL2r+YsMysalrpQXw08PH06PsB4GxgWYGxmVmFtNKDeBnwqSKCcG/MofXC76bdk38XoRd+j1XnHsRmVoq8nfpWN6yfJOn7xYRkZlXkko2ZlcLJxsxKkbeBeKyk+uEJ44Ebmh0o6VTgVIApU6aMLDozq4y8JZvXImLP2gs4a6gD3YPYzJpxNcrMSuFkY2alcLIxs1LkaiCOiHEN67OB2QXEYz3IvXMtD5dszKwUucdGAUhaAyyu23R1RHyjvSGZWRW1lGxIj8ALicTMKs3VKDMrRavJZqykRXWvYxoP8BzEZtZM26tRETELmAXQ39/vxxRmBrgaZWYlcbIxs1K0Wo1qHP19c0Sc0c6AzKyaWko2ETGqleMHBwfbOj9tET1V2z1/rnvTWjcoYl7okf5tuxplZqUYdrJpnJfYzGxdXLIxs1I42ZhZKVp9GrVe9XMQm5nVtD3Z1PcgluRHM2YGuBplZiVxsjGzUjjZmFkpht1m0zgvcTN9fX0MDAwM9xKlcI9fq6Ju/Lt2ycbMSrHeZCMpJF1Wtz5a0gpJNxYbmplVSZ6SzW+B3SWNTesfB54pLiQzq6K81aj/BP4kLU8HriomHDOrqrzJ5mrgWEljgD2AO4Y60HMQm1kzuZJNRNwHTCUr1cxdz7GzIqI/IvonTpw48gjNrBJaefR9A3A+cACwZSHRmFlltZJsLgFeiojFkg4oKB4zq6jcySYinga+W2AsZlZh6002zXoKR8R8YH4B8ZhZRbkHsZmVIneykbQmfeXuvZLulrRPkYGZWbW00kD87lfvSvpj4OvA/oVEZWaVM9xq1KbAC+0MxMyqrZWSTe3bMMcAk4CDignJzKqolZLNaxGxZ0TsDBwKXKomX7vn4Qpm1sywqlERsRCYAKw1HsHDFcysmWElG0k7A6OAVe0Nx8yqajhtNgACToyINQXEZGYV1MpwhVFFBmJm1db2L6krUpP26BFr98TQRcTY7bpxcu1GG9rnUsRnMtLfoYcrmFkpWko2kr4s6QFJ96WhCx8uKjAzq5bc1ShJHwEOB/aKiDckTQA2KiwyM6uUVtpsJgErI+INgIhYWUxIZlZFrVSj5gGTJT0i6UJJTQdhugexmTWTO9lExGqgDzgVWAFcI+mkJse5B7GZraWlR9+pE998YL6kxcCJwOz2h2VmVdPK5Fk7SZpWt2lPYGn7QzKzKmqlZDMO+J6kzYG3gUfJqlRmZuulInt/SlrB2qWfCUBVnmRV5V6qch/ge+m07SKiaWNtocmm6QWlgYjoL/WiBanKvVTlPsD30s08XMHMSuFkY2al6ESymdWBaxalKvdSlfsA30vXKr3Nxsw2TK5GmVkpnGzMrBSlJRtJh0p6WNKjks4o67pFkPSEpMVpTp+BTsfTCkmXSHpO0v1128ZLukXS/6SfW3QyxryGuJezJT2TPptFkg7rZIx5SJos6ZeSlqT5ov42be/Jz2UopSQbSaOAmcAngF2B6ZJ2LePaBTowfY9Wr/WDmE32vV/1zgB+HhHTgJ+n9V4wm7XvBeDb6bPZMyLmlhzTcLwNfDEidgH+EDgt/X/06ufSVFklmw8Bj0bEbyLiTeBq4MiSrm11ImIB8HzD5iOBOWl5DnBUqUEN0xD30nMiYnlE3J2WXwGWANvQo5/LUMpKNtsAT9WtP5229aoA5kkalFSF8WFbRcRyyP7wgfd3OJ6R+nyauvaSXqt6SJoK/AFwBxX7XMpKNs2mZe/lZ+77RsReZNXC0yR9rNMB2bt+AOxANivBcuBbnQ0nP0njgJ8AX4iIlzsdT7uVlWyeBibXrW8LLCvp2m0XEcvSz+eA68mqib3sWUmTANLP5zocz7BFxLMRsSYi3gF+SI98NpJ+hyzRXBERP02bK/O5QHnJ5i5gmqTtJW0EHAvcUNK120rSJpLeV1sG/gi4f93v6no3kE2ERvr57x2MZURq/5zJ0fTAZ6PsC5l+BCyJiAvqdlXmc4ESexCnR5DfIfuO8Esi4rxSLtxmkj5AVpqBbD6gK3vpXiRdBRxANn3Bs8BXgZ8B1wJTgCeBP4+Irm94HeJeDiCrQgXwBPDXtXaPbiVpP+A2YDHwTtp8Jlm7Tc99LkPxcAUzK4V7EJtZKZxszKwUTjZmVgonGzMrhZONmZXCyaaiJM1NX7uT9/ip9aOnu4Wkz0g6IS2fJGnrun0XV2BA7wbDj74NeHdMzo0RsXuHQxmSpPnAjIjoqWk9LOOSTQ+S9CVJp6flb0v6RVo+WNLlafkJSRNSiWWJpB+muVLmSRqbjumTdK+khcBpdecfI+nHac6eeyQdmLbPlbRHWr5H0llp+VxJf9UQ41RJD0makwZFXidp47o470nnv0TSe9P2b0h6MB1/ftp2tqQZkj4J9ANXpHlqxkqaL6k/HTc9ne9+Sd+si2O1pPPSfd4uaasCPhLLwcmmNy0APpqW+4FxaWxNrSdqo2nAzIjYDXgR+LO0/cfA6RHxkYbjTwOIiA8C04E5ksbUritpU7I5WPZNxw913Z2AWRGxB/Ay8Ll0ntnAMen8o4HPShpPNrxgt3T81+pPFBHXAQPA8Wmemtdq+1LV6pvAQWS9h/eWVJuOYRPg9oj4/RT/KU3itBI42fSmQaAvjdF6A1hIlnQ+SvN/+scjYlHde6dK2gzYPCJuTdsvqzt+v9p6RDxE9q2mO6Zzfyztv4ksyW0MTI2Ih5tc96mI+HVavjy9b6cUzyNp+5x0zpeB14GLJf0p8GreXwawNzA/IlZExNvAFemcAG8CN9bfewvntTZysulBEfEW2bifk4H/JksCB5JNrbCkyVveqFteQ1aaEENP89FsShDIBtTWktoC4B6yksLgUKE2WW967pQkPkQ28vko4OYhztlKvABvxf83TNbu3TrAyaZ3LQBmpJ+3AZ8BFkXOFv+IeBF4KQ0CBDi+4dzHA0jakWwg4MNplsWngE8Bt6frzqB5aQpgiqRaFW068CvgIbKS1e+l7Z8Gbk1zuWyWpvH8All1qNErwPuabL8D2D+1UY1K17q1yXHWQU42ves2YBKwMCKeJauCDPVPP5STgZmpgfi1uu0XAqMkLQauAU6KiFrp6Dbg2Yh4NS1vu47rLgFOlHQfMB74QUS8nq77b+n87wAXkSWRG9OxtwJ/1+R8s4GLag3EtY1pVPc/Ab8E7gXujoieno6hivzo2wrRC4/SrVwu2ZhZKVyyMbNSuGRjZqVwsjGzUjjZmFkpnGzMrBRONmZWiv8DOe86sM5fQXQAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 720x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "figure(figsize=(10,4))\n",
    "subplot(1,2,1)\n",
    "title('MFE encoding')\n",
    "imshow(np.array(inputs_train[33].reshape(25,10)).T,cmap='gist_heat_r')\n",
    "yticks(range(10), ['A','U','G','C','H','E','I','M','B','S'])\n",
    "xlabel('window position')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### To set up kaggle submission format:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_submission = pd.read_csv('sample_submission.csv.zip')\n",
    "mask = sample_submission['id_seqpos'].isin(test_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Train models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Model based on single MFE structure (primary DegScore model used)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting deg_Mg_pH10 ...\n"
     ]
    }
   ],
   "source": [
    "for output_type in ['deg_Mg_pH10']: #['reactivity', 'deg_Mg_pH10', 'deg_pH10', 'deg_Mg_50C','deg_50C']:\n",
    "    mea_outputs_train, mea_outputs_labels = encode_output(kaggle_train, data_type=output_type)\n",
    "\n",
    "    reg = Ridge(alpha=0.15, fit_intercept=False)\n",
    "    print('Fitting %s ...' % output_type)\n",
    "    #reg.fit(mea_inputs_train_construct, mea_outputs_train)\n",
    "    reg.fit(mea_inputs_train, mea_outputs_train)\n",
    "    \n",
    "    mea_models[output_type] = reg\n",
    "    \n",
    "    test_prediction = reg.predict(mea_inputs_test)\n",
    "    sample_submission.loc[mask, output_type] = test_prediction"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
